#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import time
import uuid
from html import escape
from typing import Any

from flask import make_response, request
from flask_login import current_user, login_required
from google_auth_oauthlib.flow import Flow

from api.db import InputType
from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
from common.constants import RetCode, TaskStatus
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
from common.misc_utils import get_uuid
from rag.utils.redis_conn import REDIS_CONN


@manager.route("/set", methods=["POST"])  # noqa: F821
@login_required
def set_connector():
    """Create a new connector or update an existing one.

    If the JSON body carries an ``id``, only the ``prune_freq``, ``refresh_freq``,
    ``config`` and ``timeout_secs`` fields present in the body are written to
    that connector. Otherwise a new connector is created for the current user
    with ``InputType.POLL``, ``TaskStatus.SCHEDULE`` and defaults of
    ``refresh_freq=30``, ``prune_freq=720`` and ``timeout_secs=60 * 29``.

    Returns:
        The stored connector as a dict. If the connector cannot be read back
        after the write (e.g. an update against an unknown id), an error result
        is returned instead of crashing.
    """
    req = request.json
    if req.get("id"):
        conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req}
        ConnectorService.update_by_id(req["id"], conn)
    else:
        req["id"] = get_uuid()
        conn = {
            "id": req["id"],
            "tenant_id": current_user.id,
            "name": req["name"],
            "source": req["source"],
            "input_type": InputType.POLL,
            "config": req["config"],
            "refresh_freq": int(req.get("refresh_freq", 30)),
            "prune_freq": int(req.get("prune_freq", 720)),
            "timeout_secs": int(req.get("timeout_secs", 60 * 29)),
            "status": TaskStatus.SCHEDULE,
        }
        ConnectorService.save(**conn)

    # NOTE(review): presumably gives the write time to become visible before the
    # re-read below; confirm whether this pause is still required.
    time.sleep(1)
    found, conn = ConnectorService.get_by_id(req["id"])
    if not found:
        # update_by_id on an unknown id writes nothing, so there is no row to return.
        return get_data_error_result(message="Can't find this Connector!")

    return get_json_result(data=conn.to_dict())


@manager.route("/list", methods=["GET"])  # noqa: F821
@login_required
def list_connector():
    """Return whatever ``ConnectorService.list`` reports for the current user's id."""
    connectors = ConnectorService.list(current_user.id)
    return get_json_result(data=connectors)


@manager.route("/<connector_id>", methods=["GET"])  # noqa: F821
@login_required
def get_connector(connector_id):
    """Return one connector as a dict, or an error result when the id is unknown."""
    found, connector = ConnectorService.get_by_id(connector_id)
    if found:
        return get_json_result(data=connector.to_dict())
    return get_data_error_result(message="Can't find this Connector!")


@manager.route("/<connector_id>/logs", methods=["GET"])  # noqa: F821
@login_required
def list_logs(connector_id):
    """Return one page of sync-task logs for a connector.

    The page is selected with the ``page`` (default 1) and ``page_size``
    (default 15) query parameters. The response holds ``total`` and ``logs``.
    """
    query = request.args.to_dict(flat=True)
    page = int(query.get("page", 1))
    page_size = int(query.get("page_size", 15))
    logs, total = SyncLogsService.list_sync_tasks(connector_id, page, page_size)
    return get_json_result(data={"total": total, "logs": logs})


@manager.route("/<connector_id>/resume", methods=["PUT"])  # noqa: F821
@login_required
def resume(connector_id):
    """Switch a connector to SCHEDULE (truthy ``resume`` in the body) or CANCEL."""
    body = request.json
    target_status = TaskStatus.SCHEDULE if body.get("resume") else TaskStatus.CANCEL
    ConnectorService.resume(connector_id, target_status)
    return get_json_result(data=True)


@manager.route("/<connector_id>/rebuild", methods=["PUT"])  # noqa: F821
@login_required
@validate_request("kb_id")
def rebuild(connector_id):
    """Ask ``ConnectorService.rebuild`` to rebuild this connector into ``kb_id``.

    ``kb_id`` is presumably a knowledge-base id (TODO confirm). A truthy return
    from the service is treated as an error message and reported as SERVER_ERROR.
    """
    body = request.json
    error = ConnectorService.rebuild(body["kb_id"], connector_id, current_user.id)
    if not error:
        return get_json_result(data=True)
    return get_json_result(data=False, message=error, code=RetCode.SERVER_ERROR)


@manager.route("/<connector_id>/rm", methods=["POST"])  # noqa: F821
@login_required
def rm_connector(connector_id):
    """Mark a connector cancelled, then delete it."""
    # NOTE(review): no check that the connector belongs to current_user's tenant.
    cancelled = TaskStatus.CANCEL
    ConnectorService.resume(connector_id, cancelled)
    ConnectorService.delete_by_id(connector_id)
    return get_json_result(data=True)


# Redis key prefixes for the browser-based Google Drive OAuth flow: the "state"
# entry holds the pending authorization, the "result" entry the exchanged credentials.
GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
# Lifetime given to both cache entries (seconds, per the name): fifteen minutes.
WEB_FLOW_TTL_SECS = 60 * 15


def _flow_cache_key(prefix: str, flow_id: str) -> str:
    """Join *prefix* and *flow_id* into a ``prefix:flow_id`` cache key."""
    return "{}:{}".format(prefix, flow_id)


def _web_state_cache_key(flow_id: str) -> str:
    """Return the cache key of the pending authorization state for *flow_id*."""
    return _flow_cache_key(GOOGLE_WEB_FLOW_STATE_PREFIX, flow_id)


def _web_result_cache_key(flow_id: str) -> str:
    """Return the cache key of the exchanged credentials for *flow_id*."""
    return _flow_cache_key(GOOGLE_WEB_FLOW_RESULT_PREFIX, flow_id)


def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
    if isinstance(payload, dict):
        return payload
    try:
        return json.loads(payload)
    except json.JSONDecodeError as exc:  # pragma: no cover - defensive
        raise ValueError("Invalid Google credentials JSON.") from exc


def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
    """Keep only the ``web`` OAuth client section of *credentials*.

    Raises:
        ValueError: If ``credentials["web"]`` is missing or is not a dict.
    """
    section = credentials.get("web")
    if isinstance(section, dict):
        return dict(web=section)
    raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.")


# JSON escapes for characters that could close a <script> element or break out
# of an HTML context when a JSON payload is inlined into the page. Inside JSON
# strings, "\u003c" etc. decode back to the same characters, so the value is unchanged.
_HTML_UNSAFE_JSON_ESCAPES = {
    ord("<"): "\\u003c",
    ord(">"): "\\u003e",
    ord("&"): "\\u0026",
    ord("'"): "\\u0027",
}


def _script_safe_json(obj: Any) -> str:
    """Serialize *obj* to JSON that can be embedded inline in HTML/``<script>``."""
    return json.dumps(obj).translate(_HTML_UNSAFE_JSON_ESCAPES)


def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
    """Build the HTML response shown in the Google OAuth popup window.

    The page is ``GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE`` filled with a heading,
    the HTML-escaped *message*, a JSON payload (``type``, ``status``, ``flowId``,
    ``message``) and, on success only, a ``window.close();`` snippet.
    NOTE(review): the template (defined elsewhere) presumably inlines the payload
    in a script that notifies the opener window; confirm.

    Args:
        flow_id: OAuth flow id; may be empty and may come from the request's
            ``state`` parameter.
        success: Whether authorization succeeded.
        message: Human-readable status; may come from the request's
            ``error_description`` parameter.

    Returns:
        A 200 ``text/html`` Flask response.
    """
    status = "success" if success else "error"
    auto_close = "window.close();" if success else ""
    escaped_message = escape(message)
    # flow_id and message can be attacker-controlled query parameters of the
    # unauthenticated callback. Plain json.dumps would leave "</script>" intact
    # and allow script injection, so emit an HTML-safe JSON encoding.
    payload_json = _script_safe_json(
        {
            "type": "ragflow-google-drive-oauth",
            "status": status,
            "flowId": flow_id or "",
            "message": message,
        }
    )
    html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
        heading="Authorization complete" if success else "Authorization failed",
        message=escaped_message,
        payload_json=payload_json,
        auto_close=auto_close,
    )
    response = make_response(html, 200)
    response.headers["Content-Type"] = "text/html; charset=utf-8"
    return response


@manager.route("/google-drive/oauth/web/start", methods=["POST"])  # noqa: F821
@login_required
@validate_request("credentials")
def start_google_drive_web_oauth():
    """Begin a browser-based Google Drive OAuth authorization.

    Expects a JSON body whose ``credentials`` field holds the uploaded Google
    OAuth client JSON (string or object) with a ``web`` client section and no
    ``refresh_token``. The pending flow is cached for ``WEB_FLOW_TTL_SECS``
    under a fresh flow id, which doubles as the OAuth ``state``. The cached
    fields are the owner, the client config, the PKCE code verifier (if any)
    and the creation time.

    Returns:
        ``flow_id``, the Google ``authorization_url`` and ``expires_in``.
        An error result is returned instead when the redirect URI is not
        configured, the credentials are unusable, or the flow cannot be created.
    """
    if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Google Drive OAuth redirect URI is not configured on the server.",
        )

    req = request.json or {}
    raw_credentials = req.get("credentials", "")
    try:
        credentials = _load_credentials(raw_credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

    if credentials.get("refresh_token"):
        return get_json_result(
            code=RetCode.ARGUMENT_ERROR,
            message="Uploaded credentials already include a refresh token.",
        )

    try:
        client_config = _get_web_client_config(credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

    flow_id = str(uuid.uuid4())
    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
        flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        authorization_url, _ = flow.authorization_url(
            access_type="offline",
            include_granted_scopes="true",
            prompt="consent",
            state=flow_id,
        )
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to create Google OAuth flow: %s", exc)
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
        )

    cache_payload = {
        "user_id": current_user.id,
        "client_config": client_config,
        # google-auth-oauthlib can attach a PKCE code challenge inside
        # authorization_url(). The callback builds a fresh Flow, so the matching
        # verifier must be persisted here or the token exchange is rejected.
        # This is None when no verifier was generated.
        "code_verifier": getattr(flow, "code_verifier", None),
        "created_at": int(time.time()),
    }
    REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)

    return get_json_result(
        data={
            "flow_id": flow_id,
            "authorization_url": authorization_url,
            "expires_in": WEB_FLOW_TTL_SECS,
        }
    )


@manager.route("/google-drive/oauth/web/callback", methods=["GET"])  # noqa: F821
def google_drive_web_oauth_callback():
    """Handle Google's OAuth redirect and stash the resulting credentials.

    The ``state`` query parameter is the flow id created by the start endpoint.
    On success, the exchanged credentials are cached together with the
    originating ``user_id`` under the flow's result key, for
    ``WEB_FLOW_TTL_SECS``. The pending state entry is removed on every path
    that finds it. Every path responds with the popup HTML page.

    Note: this route has no ``@login_required``. Ownership is carried by the
    ``user_id`` stored in the cached state.
    """
    state_id = request.args.get("state")
    error = request.args.get("error")
    error_description = request.args.get("error_description") or error

    if not state_id:
        return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")

    state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
    if not state_cache:
        return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")

    state_obj = json.loads(state_cache)
    client_config = state_obj.get("client_config")
    if not client_config:
        REDIS_CONN.delete(_web_state_cache_key(state_id))
        return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")

    if error:
        REDIS_CONN.delete(_web_state_cache_key(state_id))
        return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")

    code = request.args.get("code")
    if not code:
        # Drop the state like every other failure path, so it cannot be replayed.
        REDIS_CONN.delete(_web_state_cache_key(state_id))
        return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")

    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
        flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        code_verifier = state_obj.get("code_verifier")
        if code_verifier:
            # This Flow is freshly built and does not know the PKCE verifier that
            # was generated with the authorization URL. Restore it so
            # fetch_token sends it.
            flow.code_verifier = code_verifier
        flow.fetch_token(code=code)
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to exchange Google OAuth code: %s", exc)
        REDIS_CONN.delete(_web_state_cache_key(state_id))
        return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")

    creds_json = flow.credentials.to_json()
    result_payload = {
        "user_id": state_obj.get("user_id"),
        "credentials": creds_json,
    }
    REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
    REDIS_CONN.delete(_web_state_cache_key(state_id))

    return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")


@manager.route("/google-drive/oauth/web/result", methods=["POST"])  # noqa: F821
@login_required
@validate_request("flow_id")
def poll_google_drive_web_result():
    """Return the exchanged credentials for ``flow_id`` to the user who started it.

    Responds with ``RetCode.RUNNING`` while no result is cached yet, and with
    ``RetCode.PERMISSION_ERROR`` when the cached result belongs to another user.
    Once the owner has received the result, the cached entry is deleted.
    """
    body = request.json or {}
    flow_id = body.get("flow_id")
    result_key = _web_result_cache_key(flow_id)
    cached = REDIS_CONN.get(result_key)
    if not cached:
        return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")

    result = json.loads(cached)
    if result.get("user_id") == current_user.id:
        REDIS_CONN.delete(result_key)
        return get_json_result(data={"credentials": result.get("credentials")})
    return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
